import numpy as np
import torch
import tqdm as 进度条


def 权重初始归一化(模型):
    类名 = 模型.__class__.__name__
    if 类名.find("'Conv2d'") != -1:
        torch.nn.init.normal_(模型.weight.data, 0.0, 0.02)
    elif 类名.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(模型.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(模型.bias.data, 0.0)


def 载入分类列表(路径):
    文件 = open(路径, "r")
    名称列表 = 文件.read().split("\n")[:-1]
    return 名称列表


def 恢复盒子列表正常比例(盒子列表, 当前维度, 原始的尺寸):
    原始的高, 原始的宽 = 原始的尺寸
    已填充的x = max(原始的高 - 原始的宽, 0) * (当前维度 / max(原始的尺寸))
    已填充的y = max(原始的宽 - 原始的高, 0) * (当前维度 / max(原始的尺寸))

    未填充的高 = 当前维度 - 已填充的y
    未填充的宽 = 当前维度 - 已填充的x

    盒子列表[:, 0] = ((盒子列表[:, 0] - 已填充的x // 2) / 未填充的宽) * 原始的宽
    盒子列表[:, 1] = ((盒子列表[:, 1] - 已填充的y // 2) / 未填充的高) * 原始的高
    盒子列表[:, 2] = ((盒子列表[:, 2] - 已填充的x // 2) / 未填充的宽) * 原始的宽
    盒子列表[:, 3] = ((盒子列表[:, 3] - 已填充的y // 2) / 未填充的高) * 原始的高

    return 盒子列表


def 到中央处理器(张量):
    return 张量.detach().cpu()


def 盒子边界宽和高的交并比(锚定盒, 目标盒子的宽和高):
    目标盒子的宽和高 = 目标盒子的宽和高.t()
    宽1, 高1 = 锚定盒[0], 锚定盒[1]
    宽2, 高2 = 目标盒子的宽和高[0], 目标盒子的宽和高[1]
    相交面积 = torch.min(宽1, 宽2) * torch.min(高1, 高2)
    相并面积 = (宽1 * 高1 + 1e-16) + 宽2 * 高2 - 相交面积
    return 相交面积 / 相并面积


def 坐标和宽高转坐标和坐标(输入):
    输出 = 输入.new(输入.shape)
    输出[..., 0] = 输入[..., 0] - 输入[..., 2] / 2
    输出[..., 1] = 输入[..., 1] - 输入[..., 3] / 2
    输出[..., 2] = 输入[..., 0] + 输入[..., 2] / 2
    输出[..., 3] = 输入[..., 1] + 输入[..., 3] / 2
    return 输出


def 计算每批分类平均精确度(有物体判断为有物体, 置信度列表, 预测的标签列表, 目标分类列表):
    索引列表 = np.argsort(-置信度列表)
    有物体判断为有物体, 置信度列表, 预测的标签列表 = 有物体判断为有物体[索引列表], 置信度列表[索引列表], 预测的标签列表[索引列表]
    唯一分类列表 = np.unique(目标分类列表)
    计数 = 0
    平均精确度, 精确度, 召回率 = [], [], []
    for 分类 in 进度条.tqdm(唯一分类列表, desc="正在计算平均精确度"):
        索引列表 = 预测的标签列表 == 分类
        有物体判断为有物体的数量 = (目标分类列表 == 分类).sum()
        预测有物体的数量 = 索引列表.sum()

        if 预测有物体的数量 == 0 and 有物体判断为有物体的数量 == 0:
            continue
        elif 预测有物体的数量 == 0 or 有物体判断为有物体的数量 == 0:
            平均精确度.append(0)
            召回率.append(0)
            精确度.append(0)
        else:
            有物体判断为无物体的累积和列表 = (1 - 有物体判断为有物体[索引列表]).cumsum()
            有物体判断为有物体的累积和列表 = (有物体判断为有物体[索引列表]).cumsum()

            召回率曲线 = 有物体判断为有物体的累积和列表 / (有物体判断为有物体的数量 + 1e16)
            召回率.append(召回率曲线[-1])

            精确度曲线 = 有物体判断为有物体的累积和列表 / (有物体判断为有物体的累积和列表 + 有物体判断为无物体的累积和列表)
            精确度.append(精确度曲线[-1])
            计数 += 1
            平均精确度.append(计算平均精确度(召回率曲线, 精确度曲线))

    平均精确度, 精确度, 召回率 = np.array(平均精确度), np.array(精确度), np.array(召回率)
    指标f1 = 2 * 精确度 * 召回率 / (精确度 + 召回率 + 1e-16)

    return 精确度, 召回率, 平均精确度, 指标f1, 唯一分类列表.astype("int32")


def 计算平均精确度(召回率列表, 精确度列表):
    召回率平均值列表 = np.concatenate(([0.0], 召回率列表, [1.0]))
    精确度平均值列表 = np.concatenate(([0.0], 精确度列表, [0.0]))

    for 索引 in range(精确度平均值列表.size - 1, 0, -1):
        精确度平均值列表[索引 - 1] = np.maximum(精确度平均值列表[索引 - 1], 精确度平均值列表[索引])

    索引 = np.where(召回率平均值列表[1:] != 召回率平均值列表[:-1])[0]

    精确度 = np.sum((召回率平均值列表[索引 + 1] - 召回率平均值列表[索引]) * 精确度平均值列表[索引 + 1])

    return 精确度


def 统计并获取某批的指标数据(输出列表, 目标列表, 交并比阈值):
    某批的指标列表 = []
    for 样本索引 in range(len(输出列表)):
        if 输出列表[样本索引] is None:
            continue
        输出 = 输出列表[样本索引]
        预测的盒子列表 = 输出[:, :4]
        预测的分数列表 = 输出[:, 4]
        预测的标签列表 = 输出[:, -1]

        正例判断为正例的列表 = np.zeros(预测的盒子列表.shape[0])

        注释列表 = 目标列表[目标列表[:, 0] == 样本索引][:, 1:]
        目标标签列表 = 注释列表[:, 0] if len(注释列表) else []
        if len(注释列表):
            已检查的盒子列表 = []
            目标盒子列表 = 注释列表[:, 1:]

            for 预测的索引, (预测的盒子, 预测的标签) in enumerate(zip(预测的盒子列表, 预测的标签列表)):
                if len(已检查的盒子列表) == len(目标标签列表):
                    break
                if 预测的标签 not in 目标标签列表:
                    continue

                交并比, 盒子的索引 = 盒子边界的交并比(预测的盒子.unsqueeze(0), 目标盒子列表).max(0)
                if 交并比 >= 交并比阈值 and 盒子的索引 not in 已检查的盒子列表:
                    正例判断为正例的列表[预测的索引] = 1
                    已检查的盒子列表 += [盒子的索引]
        某批的指标列表.append([正例判断为正例的列表, 预测的分数列表, 预测的标签列表])
    return 某批的指标列表


def 盒子边界的交并比(预测的盒子, 目标的盒子, 不转换坐标=True):
    if not 不转换坐标:
        预测的盒子_x1, 预测的盒子_x2 = 预测的盒子[:, 0] - 预测的盒子[:, 2] / 2, 预测的盒子[:, 0] + 预测的盒子[:, 2] / 2
        预测的盒子_y1, 预测的盒子_y2 = 预测的盒子[:, 1] - 预测的盒子[:, 3] / 2, 预测的盒子[:, 1] + 预测的盒子[:, 3] / 2
        目标的盒子_x1, 目标的盒子_x2 = 目标的盒子[:, 0] - 目标的盒子[:, 2] / 2, 目标的盒子[:, 0] + 目标的盒子[:, 2] / 2
        目标的盒子_y1, 目标的盒子_y2 = 目标的盒子[:, 1] - 目标的盒子[:, 3] / 2, 目标的盒子[:, 1] + 目标的盒子[:, 3] / 2
    else:
        预测的盒子_x1, 预测的盒子_y1, 预测的盒子_x2, 预测的盒子_y2 = 预测的盒子[:, 0], 预测的盒子[:, 1], 预测的盒子[:, 2], 预测的盒子[:, 3]
        目标的盒子_x1, 目标的盒子_y1, 目标的盒子_x2, 目标的盒子_y2 = 目标的盒子[:, 0], 目标的盒子[:, 1], 目标的盒子[:, 2], 目标的盒子[:, 3]

    相交区域_x1 = torch.max(预测的盒子_x1, 目标的盒子_x1)
    相交区域_y1 = torch.max(预测的盒子_y1, 目标的盒子_y1)
    相交区域_x2 = torch.min(预测的盒子_x2, 目标的盒子_x2)
    相交区域_y2 = torch.min(预测的盒子_y2, 目标的盒子_y2)
    相交面积 = torch.clamp(相交区域_x2 - 相交区域_x1 + 1, min=0) * torch.clamp(相交区域_y2 - 相交区域_y1 + 1, min=0)

    预测的盒子的面积 = (预测的盒子_x2 - 预测的盒子_x1 + 1) * (预测的盒子_y2 - 预测的盒子_y1 + 1)
    目标的盒子的面积 = (目标的盒子_x2 - 目标的盒子_x1 + 1) * (目标的盒子_y2 - 目标的盒子_y1 + 1)

    交并比 = 相交面积 / (预测的盒子的面积 + 目标的盒子的面积 - 相交面积 + 1e-16)

    return 交并比


def 非极大值抑制(预测的列表, 置信度阈值=0.5, 非极大值抑制阈值=0.4):
    """
    测试时由于损失值过大可能导致 while 检测的列表.size(0)无限循环
    :param 预测的列表:
    :param 置信度阈值:
    :param 非极大值抑制阈值:
    :return:
    """
    预测的列表[..., :4] = 坐标和宽高转坐标和坐标(预测的列表[..., :4])
    输出 = [None for _ in range(len(预测的列表))]
    for 图片索引, 预测的图片 in enumerate(预测的列表):
        预测的图片 = 预测的图片[预测的图片[:, 4] >= 置信度阈值]
        if not 预测的图片.size(0):
            continue

        分数 = 预测的图片[:, 4] * 预测的图片[:, 5:].max(1)[0]

        预测的图片 = 预测的图片[(-分数).argsort()]
        分类置信度列表, 预测的分类列表 = 预测的图片[:, 5:].max(1, keepdim=True)

        检测的列表 = torch.cat((预测的图片[:, :5], 分类置信度列表.float(), 预测的分类列表.float()), 1)
        保留的盒子列表 = []
        while 检测的列表.size(0):
            大的重叠区域 = 盒子边界的交并比(检测的列表[0, :4].unsqueeze(0), 检测的列表[:, :4]) > 非极大值抑制阈值
            匹配的标签 = 检测的列表[0, -1] == 检测的列表[:, -1]
            无效的 = 大的重叠区域 & 匹配的标签
            权重列表 = 检测的列表[无效的, 4:5]
            # 按置信度顺序合并重叠的盒子列表
            检测的列表[0, :4] = (权重列表 * 检测的列表[无效的, :4]).sum(0) / 权重列表.sum()
            保留的盒子列表 += [检测的列表[0]]
            检测的列表 = 检测的列表[~无效的]

        if 保留的盒子列表:
            输出[图片索引] = torch.stack(保留的盒子列表)
    return 输出


def 构建目标列表(预测的盒子列表, 预测的分类列表, 目标列表, 锚定盒列表, 忽略用阈值):
    浮点型张量 = torch.cuda.FloatTensor if 预测的盒子列表.is_cuda else torch.FloatTensor
    布尔型张量 = torch.cuda.BoolTensor if 预测的盒子列表.is_cuda else torch.BoolTensor

    每批图片数量 = 预测的盒子列表.size(0)
    锚定盒数量 = 预测的盒子列表.size(1)
    类别数量 = 预测的分类列表.size(-1)
    预测的盒子尺寸 = 预测的盒子列表.size(2)

    有物体的掩码列表 = 布尔型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸).fill_(0)
    无物体的掩码列表 = 布尔型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸).fill_(1)
    分类的掩码列表 = 浮点型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸).fill_(0)
    交并比分数列表 = 浮点型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸).fill_(0)

    预设目标中x的列表 = 浮点型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸).fill_(0)
    预设目标中y的列表 = 浮点型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸).fill_(0)
    预设目标中宽的列表 = 浮点型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸).fill_(0)
    预设目标中高的列表 = 浮点型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸).fill_(0)
    预设目标分类的列表 = 浮点型张量(每批图片数量, 锚定盒数量, 预测的盒子尺寸, 预测的盒子尺寸, 类别数量).fill_(0)

    目标盒子列表 = 目标列表[:, 2:6] * 预测的盒子尺寸
    目标盒子的中心点列表 = 目标盒子列表[:, :2]
    目标盒子的宽和高列表 = 目标盒子列表[:, 2:]

    交并比列表 = torch.stack([盒子边界宽和高的交并比(锚定盒, 目标盒子的宽和高列表) for 锚定盒 in 锚定盒列表])
    # print("交并比列表", 交并比列表.shape)
    最佳_交并比列表, 最佳_交并比索引列表 = 交并比列表.max(0)

    图片索引列表, 目标标签列表 = 目标列表[:, :2].long().t()
    目标盒子x列表, 目标盒子y列表 = 目标盒子的中心点列表.t()
    目标盒子宽列表, 目标盒子高列表 = 目标盒子的宽和高列表.t()
    # 向下取整
    目标盒子y取整列表, 目标盒子x取整列表 = 目标盒子的中心点列表.long().t()

    有物体的掩码列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表] = 1
    无物体的掩码列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表] = 0

    for 索引, 交并比 in enumerate(交并比列表.t()):
        无物体的掩码列表[图片索引列表[索引], 交并比 > 忽略用阈值, 目标盒子x取整列表[索引], 目标盒子y取整列表[索引]] = 0

    # 获取网格中的位置
    预设目标中x的列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表] = 目标盒子x列表 - 目标盒子x列表.floor()
    预设目标中y的列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表] = 目标盒子y列表 - 目标盒子y列表.floor()
    预设目标中宽的列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表] = torch.log(目标盒子宽列表 / 锚定盒列表[最佳_交并比索引列表][:, 0] + 1e-16)
    预设目标中高的列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表] = torch.log(目标盒子高列表 / 锚定盒列表[最佳_交并比索引列表][:, 1] + 1e-16)
    预设目标分类的列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表, 目标标签列表] = 1

    分类的掩码列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表] = (
            预测的分类列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表].argmax(-1) == 目标标签列表).float()
    交并比分数列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表] = 盒子边界的交并比(
        预测的盒子列表[图片索引列表, 最佳_交并比索引列表, 目标盒子x取整列表, 目标盒子y取整列表], 目标盒子列表, 不转换坐标=False)

    目标置信度列表 = 有物体的掩码列表.float()
    return 交并比分数列表, 分类的掩码列表, 有物体的掩码列表, 无物体的掩码列表, 预设目标中x的列表, 预设目标中y的列表, 预设目标中宽的列表, 预设目标中高的列表, 预设目标分类的列表, 目标置信度列表
